"""
# Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
import unittest
from unittest.mock import patch

from fastdeploy.engine.sampling_params import SamplingParams


class TestSamplingParamsVerification(unittest.TestCase):
    """Tests for ``SamplingParams._verify_args``.

    The expected validation rules for ``logprobs`` and ``prompt_logprobs``
    depend on the ``FD_USE_GET_SAVE_OUTPUT_V1`` environment variable, so every
    scenario runs under an explicitly patched environment. These tests expect:

    * ``"1"`` (V1 output path): ``logprobs`` and ``prompt_logprobs`` accept
      ``None``, ``-1`` or any non-negative value; values below ``-1`` raise.
    * ``"0"`` (legacy path, also expected when the variable is unset):
      ``logprobs`` must be ``None`` or within ``[0, 20]``, and any non-``None``
      ``prompt_logprobs`` raises.

    ``SamplingParams.__post_init__`` is expected to run ``_verify_args``, so
    invalid arguments raise from the constructor itself. The explicit
    ``_verify_args()`` calls re-run the check on instances that were built
    successfully.
    """

    # Environment variable that switches SamplingParams between the legacy
    # ("0") and V1 ("1") validation rules exercised below.
    _ENV_VAR = "FD_USE_GET_SAVE_OUTPUT_V1"

    @classmethod
    def _v1_env(cls, value):
        """Return a context manager that sets FD_USE_GET_SAVE_OUTPUT_V1 to ``value``.

        ``patch.dict`` restores ``os.environ`` to its original contents on exit.
        """
        return patch.dict(os.environ, {cls._ENV_VAR: value})

    def test_logprobs_valid_values(self):
        """Accepted logprobs values in each FD_USE_GET_SAVE_OUTPUT_V1 mode."""
        # (logprobs, FD_USE_GET_SAVE_OUTPUT_V1) pairs that must validate cleanly.
        cases = [
            (None, "0"),
            (None, "1"),
            (-1, "1"),  # -1 is only allowed on the V1 path
            (0, "0"),
            (0, "1"),
            (20, "0"),  # upper bound of the legacy [0, 20] range
        ]
        for logprobs, flag in cases:
            with self.subTest(logprobs=logprobs, v1=flag), self._v1_env(flag):
                params = SamplingParams(logprobs=logprobs)
                params._verify_args()  # Should not raise

    def test_logprobs_invalid_less_than_minus_one(self):
        """logprobs < -1 raises ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1"."""
        with self._v1_env("1"):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(logprobs=-2)
                params._verify_args()

            self.assertIn("logprobs must be a non-negative value or -1", str(cm.exception))
            self.assertIn("got -2", str(cm.exception))

    def test_logprobs_invalid_less_than_zero(self):
        """logprobs < 0 raises ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "0"."""
        with self._v1_env("0"):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(logprobs=-1)
                params._verify_args()

            self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", str(cm.exception))

    def test_logprobs_greater_than_20_with_v1_disabled(self):
        """logprobs > 20 raises ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "0"."""
        with self._v1_env("0"):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(logprobs=21)
                params._verify_args()

            # Exact match: pins the full legacy message, including the trailing period.
            self.assertEqual("Invalid value for 'top_logprobs': must be between 0 and 20.", str(cm.exception))

    def test_logprobs_greater_than_20_with_v1_enabled(self):
        """logprobs > 20 is accepted when FD_USE_GET_SAVE_OUTPUT_V1 is "1"."""
        with self._v1_env("1"):
            # The V1 path has no upper bound, so try just above 20 and well above it.
            for logprobs in (21, 100):
                with self.subTest(logprobs=logprobs):
                    params = SamplingParams(logprobs=logprobs)
                    params._verify_args()  # Should not raise

    def test_prompt_logprobs_valid_values(self):
        """Accepted prompt_logprobs values in each FD_USE_GET_SAVE_OUTPUT_V1 mode."""
        # None is accepted in both modes; any other value requires the V1 path.
        cases = [
            (None, "0"),
            (None, "1"),
            (-1, "1"),
            (0, "1"),
            (10, "1"),
        ]
        for prompt_logprobs, flag in cases:
            with self.subTest(prompt_logprobs=prompt_logprobs, v1=flag), self._v1_env(flag):
                params = SamplingParams(prompt_logprobs=prompt_logprobs)
                params._verify_args()  # Should not raise

    def test_prompt_logprobs_invalid_less_than_minus_one(self):
        """prompt_logprobs < -1 raises ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1"."""
        with self._v1_env("1"):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(prompt_logprobs=-2)
                params._verify_args()

            # NOTE(review): the expected text mirrors the production message as
            # written ("a must be"); update both together if that message is fixed.
            self.assertIn("prompt_logprobs a must be non-negative value or -1", str(cm.exception))
            self.assertIn("got -2", str(cm.exception))

    def test_combined_logprobs_and_prompt_logprobs(self):
        """logprobs and prompt_logprobs are validated independently when both are set."""
        # Both valid on the V1 path.
        with self._v1_env("1"):
            params = SamplingParams(logprobs=5, prompt_logprobs=3)
            params._verify_args()  # Should not raise

        # Invalid logprobs with valid prompt_logprobs on the V1 path.
        with self._v1_env("1"):
            with self.assertRaises(ValueError):
                params = SamplingParams(logprobs=-2, prompt_logprobs=5)
                params._verify_args()

        # Valid logprobs with invalid prompt_logprobs on the V1 path.
        with self._v1_env("1"):
            with self.assertRaises(ValueError):
                params = SamplingParams(logprobs=5, prompt_logprobs=-2)
                params._verify_args()

        # prompt_logprobs is rejected on the legacy path even when logprobs is valid.
        with self._v1_env("0"):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(logprobs=5, prompt_logprobs=3)
                params._verify_args()
            self.assertIn(
                "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(cm.exception)
            )

    def test_logprobs_boundary_values(self):
        """The legacy path accepts logprobs=20 and rejects logprobs=21."""
        # At the limit with V1 disabled.
        with self._v1_env("0"):
            params = SamplingParams(logprobs=20)
            params._verify_args()  # Should pass

        # Just above the limit with V1 disabled.
        with self._v1_env("0"):
            with self.assertRaises(ValueError):
                params = SamplingParams(logprobs=21)
                params._verify_args()

    def test_prompt_logprobs_boundary_values(self):
        """The V1 path accepts prompt_logprobs=-1 and rejects prompt_logprobs=-2."""
        # Lowest accepted value.
        with self._v1_env("1"):
            params = SamplingParams(prompt_logprobs=-1)
            params._verify_args()  # Should pass

        # Just below the lowest accepted value.
        with self._v1_env("1"):
            with self.assertRaises(ValueError):
                params = SamplingParams(prompt_logprobs=-2)
                params._verify_args()

    def test_environment_variable_handling(self):
        """Validation under "0", "1" and an unset FD_USE_GET_SAVE_OUTPUT_V1."""
        # "0" (legacy behavior): logprobs above 20 is rejected.
        with self._v1_env("0"):
            with self.assertRaises(ValueError):
                params = SamplingParams(logprobs=21)
                params._verify_args()

        # "1" (relaxed behavior): logprobs above 20 is accepted.
        with self._v1_env("1"):
            params = SamplingParams(logprobs=21)
            params._verify_args()  # Should pass

        # Unset: expected to fall back to the legacy ("0") rules.
        # patch.dict snapshots os.environ and restores it on exit, so removing
        # the variable here cannot leak into other tests, even on failure.
        with patch.dict(os.environ):
            os.environ.pop(self._ENV_VAR, None)
            with self.assertRaises(ValueError):
                params = SamplingParams(logprobs=21)
                params._verify_args()

        # prompt_logprobs is rejected on the legacy path ...
        with self._v1_env("0"):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(prompt_logprobs=5)
                params._verify_args()
            self.assertIn(
                "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(cm.exception)
            )

        # ... and accepted on the V1 path.
        with self._v1_env("1"):
            params = SamplingParams(prompt_logprobs=5)
            params._verify_args()  # Should pass

    def test_error_message_formatting(self):
        """Error messages name the rule that was violated and echo the offending value."""
        # logprobs below -1 on the V1 path.
        with self._v1_env("1"):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(logprobs=-5)
                params._verify_args()

            error_msg = str(cm.exception)
            self.assertIn("logprobs must be a non-negative value or -1", error_msg)
            self.assertIn("got -5", error_msg)

        # logprobs below 0 on the legacy path.
        with self._v1_env("0"):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(logprobs=-1)
                params._verify_args()

            error_msg = str(cm.exception)
            self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", error_msg)

        # prompt_logprobs below -1 on the V1 path.
        with self._v1_env("1"):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(prompt_logprobs=-10)
                params._verify_args()

            error_msg = str(cm.exception)
            self.assertIn("prompt_logprobs a must be non-negative value or -1", error_msg)
            self.assertIn("got -10", error_msg)

        # prompt_logprobs on the legacy path.
        with self._v1_env("0"):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(prompt_logprobs=5)
                params._verify_args()

            error_msg = str(cm.exception)
            # NOTE(review): mirrors the production wording ("is not support") verbatim.
            self.assertIn("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", error_msg)

    def test_post_init_calls_verify_args(self):
        """Invalid arguments are rejected by the constructor, not only by _verify_args."""
        with self._v1_env("1"):
            params = SamplingParams(logprobs=5, prompt_logprobs=3)

            # Valid values are stored unchanged.
            self.assertEqual(params.logprobs, 5)
            self.assertEqual(params.prompt_logprobs, 3)

            # Invalid values are caught during initialization.
            with self.assertRaises(ValueError):
                SamplingParams(logprobs=-2)

            with self.assertRaises(ValueError):
                SamplingParams(prompt_logprobs=-2)

        with self._v1_env("0"):
            # prompt_logprobs is not supported on the legacy path.
            with self.assertRaises(ValueError):
                SamplingParams(prompt_logprobs=3)

            # logprobs < 0 is not supported on the legacy path.
            with self.assertRaises(ValueError):
                SamplingParams(logprobs=-1)

    def test_logprobs_with_other_parameters(self):
        """logprobs validation is unaffected by other sampling parameters."""
        # With temperature on the V1 path.
        with self._v1_env("1"):
            params = SamplingParams(logprobs=5, temperature=0.8)
            params._verify_args()  # Should pass

        # With top_p on the V1 path.
        with self._v1_env("1"):
            params = SamplingParams(logprobs=5, top_p=0.9)
            params._verify_args()  # Should pass

        # With several parameters on the V1 path.
        with self._v1_env("1"):
            params = SamplingParams(
                logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100
            )
            params._verify_args()  # Should pass

        # prompt_logprobs still fails on the legacy path when combined with other parameters.
        with self._v1_env("0"):
            with self.assertRaises(ValueError):
                params = SamplingParams(
                    logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100
                )
                params._verify_args()


if __name__ == "__main__":
    # Running this file directly discovers and runs every TestCase defined above.
    unittest.main(module="__main__")
